import numpy as np
from sklearn import datasets
from FastL1Basis import getFCT1
from Solve_mu import solve_mu_exact, solve_mu_approx, getUBneg
from sklearn.datasets import fetch_kddcup99

from sklearn.preprocessing import scale
import scipy

def get_kdd_data(npos,nneg):
    """Draw a random, standardized sample from the KDD Cup 99 "SA" subset.

    The 10% "SA" subset is fetched with ``fetch_kddcup99``, labels are
    encoded as +1 (anything other than ``b"normal."``) and -1 (normal),
    and only the continuous feature columns are kept.  ``nneg`` negative
    rows followed by ``npos`` positive rows are then drawn at random and
    the selected rows are standardized column-wise with ``scale``.

    Parameters
    ----------
    npos : int
        Number of positive (+1) rows to draw.  Must not exceed the number
        of positive rows present in the fetched subset.
    nneg : int
        Number of negative (-1) rows to draw.

    Returns
    -------
    X : numpy.ndarray
        Standardized float32 feature matrix with ``npos + nneg`` rows and
        33 columns (the 34 continuous features minus ``num_outbound_cmds``);
        negatives come first, positives last.
    y_sub : pandas.Series
        The +1/-1 labels of the selected rows, in the same order as ``X``.

    Raises
    ------
    AssertionError
        If ``npos`` is larger than the number of available positive rows.

    Notes
    -----
    Rows are drawn with the global NumPy random state and *with*
    replacement (the ``np.random.choice`` default), so the same row can
    appear more than once in the sample.
    """
    X, y = fetch_kddcup99(
        subset="SA", percent10=True, random_state=42, return_X_y=True, as_frame=True
    )

    # Encode the target as +1 (attack) / -1 (normal traffic).
    y = (y != b"normal.").astype(np.int32)
    y[y==0] = -1

    features_continuous = [
        "duration",
        "src_bytes",
        "dst_bytes",
        "wrong_fragment",
        "urgent",
        "hot",
        "num_failed_logins",
        "num_compromised",
        "root_shell",
        "su_attempted",
        "num_root",
        "num_file_creations",
        "num_shells",
        "num_access_files",
        "num_outbound_cmds",
        "count",
        "srv_count",
        "serror_rate",
        "srv_serror_rate",
        "rerror_rate",
        "srv_rerror_rate",
        "same_srv_rate",
        "diff_srv_rate",
        "srv_diff_host_rate",
        "dst_host_count",
        "dst_host_srv_count",
        "dst_host_same_srv_rate",
        "dst_host_diff_srv_rate",
        "dst_host_same_src_port_rate",
        "dst_host_srv_diff_host_rate",
        "dst_host_serror_rate",
        "dst_host_srv_serror_rate",
        "dst_host_rerror_rate",
        "dst_host_srv_rerror_rate",
    ]

    X_continuous = X[features_continuous]

    # the feature num_outbound_cmds has only one value that doesn't
    # change, so drop it
    X_continuous = X_continuous.drop("num_outbound_cmds", axis="columns")

    # convert to numpy array
    X_continuous = X_continuous.to_numpy()

    # Build the sampling pools from the labels themselves rather than from a
    # hard-coded positive count and an assumed "positives are at the end"
    # layout: a count that is off by even one row would let a positive row
    # leak into the negative pool.
    labels = y.to_numpy()
    pos_pool = np.flatnonzero(labels == 1)
    neg_pool = np.flatnonzero(labels == -1)

    assert npos <= len(pos_pool), (
        "requested {} positive rows but only {} are available".format(
            npos, len(pos_pool)
        )
    )

    # Negatives first, then positives (sampling is with replacement).
    neg_indices = np.random.choice(neg_pool, nneg)
    pos_indices = np.random.choice(pos_pool, npos)
    all_indices = np.concatenate((neg_indices, pos_indices))

    X_sub = np.float32(X_continuous[all_indices, :])

    # The pools hold row positions, so select labels positionally as well.
    y_sub = y.iloc[all_indices]

    # scale the features to mean 0 and variance 1
    X = scale(X_sub)

    return X, y_sub


if __name__ == "__main__":
    # Experiment driver: for each (npos, nneg) configuration, sample the data,
    # solve for mu on the full label-signed matrix, then on a sketch of it.
    allnpos = [512,1024,2048,2048]
    allnneg = [512,1024,2048,6144]
    for npos, nneg in zip(allnpos, allnneg):
        X, y = get_kdd_data(npos, nneg)
        nn, d = np.shape(X)
        print("Shape of X n={}, d={}".format(nn,d))

        # A = diag(y) @ X, i.e. every row of X multiplied by its +1/-1 label.
        # Broadcasting gives the same matrix without materialising the dense
        # n x n diagonal matrix.
        A = np.asarray(y)[:, np.newaxis] * X

        # Solve on the full matrix A; the sketch U is only built further down,
        # so it cannot be the input of this reference solve.
        betastar, mu_exact = solve_mu_exact(A, 'PDLP')
        print(" npos = {} nneg = {} mu = {}".format(npos, nneg, mu_exact))

        allr1 = [32]
        for r1 in allr1:
            s = r1 * 4  # alternative considered: np.power(r1, 3)
            U = getFCT1(A, r1, s)
            t = solve_mu_approx(U, r1, s)
            print("lb = {} ub = {}".format(1.0 / t, 10 / t),flush=True)

    print("done")

